#!/usr/bin/env python3

"""
Copyright 2020 The Magma Authors.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# Standard Python and PyPi modules
import argparse
import ipaddress
import iperf3
import logging
import logging.handlers
import multiprocessing
import os
import pyroute2
import socketserver
import subprocess
import sys
import threading
import traceback

# Custom modules
from util.traffic_messages import TrafficServerInstance, TrafficRequest, \
    TrafficRequestType, TrafficResponse, TrafficResponseType, TrafficMessage

'''
Overview of the system
======================

    traffic_server.py                                traffic_util.py

TrafficTestServerDispatcher                            TrafficUtil
           ^|      |                                     |  |
           ||      '---TrafficTestServer...TrafficTest---'  |
           ||                                               |
           ||  (receives incoming connection)               |
           ||                                               |
           |V    (message passing with traffic_messages.py) V
    TrafficTestServer  <---------------------------->  TrafficTest
            |      |                                   ^    |    |
            |      '---TrafficTestDriver...            |    |    |
            |                                          |    |    |
            |  (receives test request and config)      |    |    |
            V                                          |    |    |
    TrafficTestDriver  --------------------------------'    |    |
            |  |                                            |    |
            |  '------------iperf3..........iperf3----------'    |
            |                                                    |
            |  (with coordination from TrafficTest)              |
            V                                                    V
          iperf3  <---generated uplink/downlink traffic---->  iperf3

TrafficTestServerDispatcher starts as an independent process that awaits
incoming connection requests into its serversocket.

When TrafficTest objects generated by TrafficUtil are run, they establish a
request with the target TrafficTestServerDispatcher via TCP and communicate
over sockets, using messages defined in util/traffic_messages.py.

When TrafficTest objects begin a test, the associated TrafficTestServer creates
and spins off a TrafficTestDriver object to handle setup for that instance of
the test. Messages are still passed through TrafficTestServer, which uses
identifiers to determine associations with TrafficTestDriver objects.

Other clients of this system may choose to perform non-testing actions upon
connection, e.g. initiating a SHUTDOWN for the entire system via a script.

Hierarchically, TrafficTestServerDispatcher objects run as the main thread.
During testing, multiple TrafficTestServer objects are invoked, one per
connection established; it acts as a monitor for TrafficTestDriver objects,
which physically run the iperf3 servers requested by a TrafficTest. Messages
are only passed between TrafficTestServer and TrafficTest, and from
TrafficTestDriver to TrafficTest. The data packets used to test traffic
throughput pass between the iperf3 instances to simulate a user sending
traffic to and receiving traffic from the Internet.
'''


class TrafficTestServerDispatcher(socketserver.TCPServer):
    ''' Main server dispatcher class for setting up traffic test servers

    Uses forks to handle each incoming connection and establishes a duplex pipe
    with each. This means that each fork has two open connections: one to
    handle the external connection, and one to handle communication with the
    main process (dispatcher).

    Implementation is heavily inspired by socketserver.ForkingTCPServer.

    Note: we choose forking over threading because the iperf3 objects that we
    create with the iperf3 module must run on the main thread, else it will run
    into a seg fault in the iperf3 binary.
    '''
    def __init__(self, daemon, host, port, loglevel):
        ''' Initialize the TrafficTestServerDispatcher

        Args:
            daemon (bool): whether to run in daemon mode (forked)
            host (ipaddress.ip_address): the IP address to bind
            port (int): the port to use
            loglevel (str): the name of the log level to use
        '''
        # If we want to daemonize, then fork the process and run on the child.
        # Just remember that SHUTDOWN can only be initiated locally and via the
        # SHUTDOWN message in util/traffic_messages.py. At the moment, there's
        # no easy way to do that, short of manually setting up the socket in
        # Python and sending the constructed message.
        #
        # In short: daemonize at your own risk!
        if daemon:
            if os.fork():
                sys.exit()

        super(TrafficTestServerDispatcher, self).__init__(
            (host.exploded, port), TrafficTestServer)

        pid = os.getpid()

        # Set up log formatting as chunks, later to be space-joined
        log_format = '%(asctime)s.%(msecs)03.0f [PID %(process)d] ' \
            '%(levelname)s (' + '%s:%s' % self.server_address + \
            '%(remote)s) %(message)s'
        log_time_format = '%Y%m%d-%H%M%S'
        # Set up log handling: syslog when daemonized, stderr otherwise
        log_handler = logging.handlers.SysLogHandler(address='/dev/log') \
            if daemon else logging.StreamHandler()
        log_handler.setFormatter(logging.Formatter(
            log_format, log_time_format))

        # Set up logging, use a different logger for each dispatcher process
        self._base_logger = logging.getLogger(__name__ + '(%d)' % pid)
        self._base_logger.addHandler(log_handler)
        self._base_logger.setLevel(getattr(logging, loglevel))
        # The format string references %(remote)s, so every record must be
        # emitted through an adapter that supplies that key
        self._logger = logging.LoggerAdapter(self._base_logger, {'remote': ''})

        # Now initialize dispatcher
        self.log.debug('Initializing dispatcher...')
        self.log.log(
            getattr(logging, loglevel), 'Setting log level to: %s', loglevel)
        self._procs = {}  # Process: (Connection, Thread)
        self._procs_lock = threading.RLock()  # Serverside lock for procs map

        # Initialize dispatcher-server messaging objects. The three objects
        # are alive at the same time, so their ids are pairwise distinct; only
        # the integer ids are kept and used as message tokens over the pipes.
        objects = [object() for _ in range(3)]

        # Sent from dispatcher to server
        self._dispatcher_server_terminate_message = id(objects[0])

        # Sent from server to dispatcher
        self._server_dispatcher_shutdown_message = id(objects[1])
        self._server_dispatcher_terminate_ack_message = id(objects[2])

        # Broadcast
        self.log.info(
            'Dispatcher running on %s:%d (%sdaemon mode)',
            host.exploded, port, '' if daemon else 'non-')

    def _release_server(self, conn, proc):
        ''' Terminate a forked server and drop all dispatcher-side state
        associated with it

        Args:
            conn (multiprocessing.Connection): the dispatcher's end of the pipe
            proc (multiprocessing.Process): the forked server process
        '''
        proc.terminate()
        conn.close()
        with self._procs_lock:
            # pop() rather than del: the entry may already be gone
            self._procs.pop(proc, None)

    def _handle_server_messaging(self, conn, proc):
        ''' Handle the connection with each forked server on a dedicated
        thread

        Yes, it definitely feels kind of recursive...

        The loop ends when the server acknowledges termination, or when the
        pipe reports that the server went away without acknowledging (e.g.
        the handler process crashed); in both cases the process is terminated
        and removed from the bookkeeping map.

        Args:
            conn (multiprocessing.Connection): the dispatcher's end of the pipe
            proc (multiprocessing.Process): the forked server process
        '''
        pid = proc.pid  # PID of the server process
        while (not conn.closed) and conn.readable:
            try:
                msg = conn.recv()
            except (EOFError, OSError):
                # recv() raises EOFError once the other end is closed and
                # nothing is left to read; treat it like an implicit ACK so
                # this thread and the _procs entry do not linger forever
                self.log.warning(
                    'Lost pipe to handler process %d before TERMINATE_ACK',
                    pid)
                self._release_server(conn, proc)
                return

            if self.SHUTDOWN == msg:
                self.log.debug('Received SHUTDOWN from %d', pid)
                self.signal_shutdown()
            elif self.TERMINATE_ACK == msg:
                self.log.debug('Received TERMINATE_ACK, terminating %d', pid)
                self._release_server(conn, proc)
                self.log.debug('Handler process %d terminated', pid)
                return  # Here, the thread joins

    @property
    def SHUTDOWN(self):
        ''' Server => dispatcher message: shutdown the system '''
        return self._server_dispatcher_shutdown_message

    @property
    def TERMINATE(self):
        ''' Dispatcher => server message: prepare for termination '''
        return self._dispatcher_server_terminate_message

    @property
    def TERMINATE_ACK(self):
        ''' Server => dispatcher message: necessary cleanup complete, proceed
        with termination '''
        return self._server_dispatcher_terminate_ack_message

    def finish_request(self, connection, request, client_address):
        ''' Create an instance of RequestHandlerClass, the class for servers,
        using the given connection

        Runs inside the forked server process and never returns: once the
        handler is done (or has failed) it notifies the dispatcher with
        TERMINATE_ACK and exits the process.

        Args:
            connection (multiprocessing.Connection): the server's connection
                in a pipe to the dispatcher
            request: inbound request object
            client_address: (ip, port) pair of the remote connector
        '''
        status = os.EX_OK
        try:
            self.RequestHandlerClass(connection, request, client_address, self)
            self.log.debug(
                'Ended connection with %s:%s', *client_address[:2])
        except Exception:
            # Without this, a failing handler would skip the TERMINATE_ACK
            # below and leave the dispatcher waiting on this process
            status = os.EX_SOFTWARE
            self.log.exception(
                'Unhandled error while serving %s:%s', *client_address[:2])
        finally:
            try:
                # Formally close the fork and connection handling thread
                if (not connection.closed) and connection.writable:
                    connection.send(self.TERMINATE_ACK)
            except OSError:
                self.log.warning('Could not send TERMINATE_ACK to dispatcher')
            finally:
                # _exit: leave immediately, skipping interpreter cleanup
                os._exit(status)

    def get_server_loggers(self, remote_ip, remote_port):
        ''' Create a tuple of logger adapters based on the dispatcher's logger
        configuration and the server's connection

        Args:
            remote_ip (str): the remote client's IP address
            remote_port (str): the remote client's port number

        Returns a (general, inbound, outbound) tuple of LoggerAdapter objects
        that differ only in the marker placed before the remote address
        '''
        def adapter(marker):
            return logging.LoggerAdapter(
                self._base_logger,
                {'remote': '%s%s:%s' % (marker, remote_ip, remote_port)})

        # '--': general-purpose log, to indicate the server and connection
        # '<=': logging inbound messages
        # '=>': logging outbound messages
        return adapter('--'), adapter('<='), adapter('=>')

    @property
    def log(self):
        ''' Access the logger '''
        return self._logger

    def process_request(self, request, client_address):
        ''' Fork off a server to serve the request

        Args:
            request: inbound request object (the accepted socket)
            client_address: (ip, port) pair of the remote connector
        '''
        dconn, sconn = multiprocessing.Pipe()
        proc = multiprocessing.Process(
            target=self.finish_request, args=(sconn, request, client_address))
        thread = threading.Thread(
            target=self._handle_server_messaging, args=(dconn, proc))

        try:
            proc.start()
        except Exception:
            # Nothing was forked, so nobody else will ever close the pipe
            dconn.close()
            sconn.close()
            raise

        # The child holds its own copies from here on. Close ours explicitly
        # (as socketserver.ForkingMixIn does for the request) instead of
        # relying on garbage collection. close_request() only closes the
        # parent's descriptor; it does not shut the connection down.
        sconn.close()
        self.close_request(request)

        with self._procs_lock:
            self._procs[proc] = (dconn, thread)

        thread.start()

    def signal_shutdown(self):
        ''' Send the signal for all forks to end

        The shutdown procedure is as follows:
            - A server receives a SHUTDOWN message from a client
            - The server sends the SHUTDOWN message to the dispatcher
            - The dispatcher calls this function to broadcast a TERMINATE to
              all the servers
            - All servers respond with TERMINATE_ACK when ready
            - Upon receiving a TERMINATE_ACK, the dispatcher calls terminate()
              on its Process
            - Once all Processes are terminated, the dispatcher exits

        The goal of the TERMINATE and TERMINATE_ACK procedure is to give the
        servers a chance to do any cleanup, e.g. removing zombified children.
        At the moment, however, this does not actually happen, as the
        TrafficTestServer implementation simply responds with TERMINATE_ACK as
        soon as it receives a TERMINATE.
        '''
        with self._procs_lock:
            self.log.info('Received signal to shut down dispatcher')
            for proc, (conn, _) in self._procs.items():
                if conn.closed or not conn.writable:
                    continue
                self.log.debug('Sending TERMINATE to %d', proc.pid)
                try:
                    conn.send(self.TERMINATE)
                except OSError:
                    # The server is already gone; keep notifying the others
                    self.log.warning(
                        'Could not send TERMINATE to %d', proc.pid)

        self.log.debug('Sent TERMINATE to all servers')
        self.shutdown()
        self.log.info('Dispatcher is shutting down!')


class TrafficTestServer(socketserver.StreamRequestHandler):
    ''' Class for handling each request from the testing framework, or possibly
    other sources who access the dispatcher

    Defines message handling for each external connection. Clients can send
    messages to set up test iperf3 instances, send configuration details, and
    retrieve test statistics results.

    Also handles communication with the dispatcher in a two-way manner, thus
    allowing for message passing between the dispatcher and each server for
    events like shutdown.
    '''
    def __init__(self, connection, *args):
        ''' Create a server with some additional arguments

        Args:
            connection (multiprocessing.Connection): the server's end of the
                pipe
            args (list(...)): the args to pass into the superclass constructor
        '''
        # Must be stored first: setup() and handle() run as part of the
        # superclass constructor and already use the pipe
        self._conn = connection
        super(TrafficTestServer, self).__init__(*args)

    def _handle_dispatcher_messaging(self):
        ''' Handle communication with the parent dispatcher process on a
        dedicated thread

        Returns after answering a TERMINATE with TERMINATE_ACK, or when the
        pipe can no longer be read
        '''
        while (not self._conn.closed) and self._conn.readable:
            try:
                msg = self._conn.recv()
            except (EOFError, OSError):
                # recv() raises EOFError when the other end is closed and
                # nothing is left to read; there is nobody left to talk to
                self.log.debug('Dispatcher pipe closed, ending listener')
                return

            if self.dispatcher.TERMINATE == msg:
                self.log.debug('Responding to TERMINATE with TERMINATE_ACK')
                self._conn.send(self.dispatcher.TERMINATE_ACK)
                return  # Here, the thread joins

    def handle(self):
        ''' Handles communications with the client, overrides base class
        implementation of the function '''
        # readable is a method on the stream object and has to be called; the
        # bare attribute is always truthy
        while (not self.rfile.closed) and self.rfile.readable():
            msg, identifier, payload = self.recv_req()

            if msg is TrafficRequestType.EXIT:
                self.log.debug('Ending connection due to EXIT message')
                return

            elif msg is TrafficRequestType.SHUTDOWN:
                server_ip = self.dispatcher.server_address[0]
                client_ip = self.client_address[0]
                if server_ip == client_ip:  # Only allow shutdown from local IP
                    self.log.info('Initiating server SHUTDOWN')
                    self._conn.send(self.dispatcher.SHUTDOWN)
                else:
                    template = 'SHUTDOWN request of %s from %s rejected;' \
                        ' client must have the same IP as the server to' \
                        ' initiate SHUTDOWN procedure'
                    # NOTE(review): send_resp() is documented to take a
                    # TrafficResponseType, so INFO is taken from that enum;
                    # confirm against util/traffic_messages.py
                    self.send_resp(
                        TrafficResponseType.INFO,
                        payload=template % (server_ip, client_ip))

            elif msg is TrafficRequestType.START:
                self.log.debug('Waiting to store START message')
                with self._start_msgs_cond:
                    self._start_msgs.add(identifier)
                    self.log.debug('Stored START message, notifying drivers')
                    self._start_msgs_cond.notify_all()

            elif msg is TrafficRequestType.TEST:
                try:
                    # The driver is constructed here, on this thread; only its
                    # run() is handed to a new thread
                    driver = TrafficTestDriver(self, payload)
                    threading.Thread(target=driver.run).start()
                except Exception:
                    # Keep serving the connection; just record the failure
                    self.log.error(traceback.format_exc())

            else:
                # Fall back to the raw value if it carries no name
                self.log.warning(
                    'Message of type %s received but not recognized.'
                    ' Perhaps the system needs a reboot?',
                    getattr(msg, 'name', msg))

        # Ending due to rfile abruptly closing
        self.log.warning('Ending handling due to abruptly closed connection')

    @property
    def log(self):
        ''' Access the logger '''
        return self._logger

    def recv_req(self, level=logging.DEBUG):
        ''' Receive an object from the client over the socket connection

        Args:
            level (int): the logging level to use; defaults to logging.DEBUG

        Return (message, id, payload) tuple extracted from the message
        '''
        msg = TrafficMessage.recv(self.rfile)
        if not msg:  # Empty message -- treat as EXIT request
            self._login.warning('Received empty message, treating as EXIT')
            msg = TrafficRequest(TrafficRequestType.EXIT)
        self._login.log(level, msg)
        return msg.message, msg.id, msg.payload

    def send_resp(self, message, id=None, payload=None, level=logging.DEBUG):
        ''' Send a message to the client over the socket connection

        Args:
            message (TrafficResponseType): the message type
            id (int): a unique identifier, if needed; defaults to None
            payload (object): the payload object, must be picklable; defaults
                to None
            level (int): the logging level to use; defaults to logging.DEBUG
        '''
        msg = TrafficResponse(message, id, payload)
        self._logout.log(level, msg)
        msg.send(self.wfile)

    def setup(self):
        ''' Called during construction before launching into handling loop '''
        super(TrafficTestServer, self).setup()

        # Alias to be consistent with our abstractions
        self.dispatcher = self.server

        # Set up logging
        self._logger, self._login, self._logout = \
            self.dispatcher.get_server_loggers(*self.client_address)

        self.log.info('Connection established')

        # Initialize server: identifiers of received START messages, guarded
        # by a condition the drivers wait on
        self._start_msgs = set()
        self._start_msgs_cond = threading.Condition()

        # Set up listener for communication with the dispatcher
        threading.Thread(target=self._handle_dispatcher_messaging).start()

    def wait_for_start_msg(self, driver):
        ''' Wait for the start message associated with the given driver from
        the start message holding structure

        This is a blocking call that will not return until the start message
        appears

        Args:
            driver (TrafficTestDriver): the driver to which the START message
                should have an identifier
        '''
        d_id = id(driver)  # Used as the key in the holding structure

        with self._start_msgs_cond:
            self.log.debug('Waiting on START message for %d', d_id)
            self._start_msgs_cond.wait_for(lambda: d_id in self._start_msgs)
            self.log.debug('START message for %d received', d_id)

            # Proceed; now remove the indicator
            self._start_msgs.discard(d_id)

            # Wake up other threads if there are still messages to process
            if self._start_msgs:
                self.log.debug('There are more messages, notifying all again')
                self._start_msgs_cond.notify_all()

        self.log.debug('Driver %d successfully received START', d_id)


class TrafficTestDriver(object):
    ''' Driver for creating the iperf3 instances, monitoring them, caching the
    results, and reporting the results '''

    # Last port handed out by _get_port(); the first port used is _port + 1.
    # NOTE(review): this counter lives in the class, so every forked server
    # process counts from its own copy -- confirm that concurrent connections
    # cannot end up requesting the same ports.
    _port = 5000
    _port_lock = threading.Lock()

    # Test-bed specifics shared by all iperf3 instances
    _BIND_ADDRESS = '192.168.129.42'  # Local address every iperf3 binds to
    _IFNAME = 'eth2'  # Interface whose MAC address is reported to the client
    _BANDWIDTH = 10 ** 7  # 10 Mbps

    def __init__(self, server, instances):
        ''' Creates a driver, ready to launch the test setup

        Args:
            server (TrafficTestServer): the server for the connection that
                requested the test
            instances (list(TrafficTestInstance)): the test instance
                configuration objects that have been send over the wire
        '''
        # Party is 1 per instance + main thread
        self._barrier = threading.Barrier(1 + len(instances))
        self._server = server
        self._instances = instances
        self._results = None

        self._setup_iperf3()

    def _get_macs(self):
        ''' Retrieves the MAC addresses of the associated test servers, based
        on the information of the instances

        Returns a tuple with one entry per instance; every entry is the MAC
        address of the same local interface
        '''
        ip = pyroute2.IPRoute()
        try:
            index = ip.link_lookup(ifname=self._IFNAME)[0]
            mac = ip.link('get', index=index)[0].get_attr('IFLA_ADDRESS')
        finally:
            # Release the netlink socket instead of leaking one per test
            ip.close()

        return (mac,) * len(self._instances)

    @staticmethod
    def _get_port():
        ''' Returns the next port for testing '''
        with TrafficTestDriver._port_lock:
            TrafficTestDriver._port += 1
            return TrafficTestDriver._port

    def _run_iperf3(self, results_buffer, iperf):
        ''' Runs the given iperf3 and writes the results into the given list
        buffer by appending the result to it

        Note: this driver function is specifically tailored to the purposes of
        traffic testing. Key differences from calling IPerf3.run():
            - only uses the IPerf3 objects as containers
            - uses subprocess to run iperf3 with the given arguments
            - as a result, only picks some arguments that are expected to exist
              and ignores others
        This is done to enable multiple instances to run concurrently in
        threads without worrying about race conditions on stdout, as the iperf3
        module's run() will reroute stdout to a private pipe. To avoid this, we
        use subprocess to give each iperf run its own private pipe. The only
        other option is to fork each child iperf into its own subprocess, which
        is not too different from this implementation.

        Args:
            results_buffer (list): the results buffer to which the
                iperf3.TestResult object to should be appended
            iperf (iperf3.IPerf3): the iperf3 object to run
        '''
        # Constructing the subprocess call
        params = ('-B', iperf.bind_address, '-p', str(iperf.port), '-J')
        if 'c' == iperf.role:
            params = ('-c', iperf.server_hostname) + params
            params += ('-b', str(iperf.bandwidth), '-t', str(iperf.duration))
            if 'udp' == iperf.protocol:
                params += ('-u',)
        else:
            params = ('-s',) + params
            params += ('-1',)
        params = ('iperf3',) + params

        # Make the iperf3 call and spin off the subprocess
        self._server.log.debug('Running iperf3 command: %s', ' '.join(params))
        try:
            with subprocess.Popen(params, stdout=subprocess.PIPE) as proc:
                result_str = proc.stdout.read().decode('utf-8')
            results_buffer.append(iperf3.TestResult(result_str))
        finally:
            # Always reach the barrier: run() is blocked on it, so a failed
            # launch or unparsable output must not leave the driver hanging
            self._barrier.wait()

    def _setup_iperf3(self):
        ''' Set up the iperf3 servers for the test instances. Must be run on a
        main thread '''
        iperfs = []
        for instance in self._instances:
            if instance.is_uplink:
                # Uplink: we are the receiving side, on a port of our choosing
                iperf = iperf3.Server()
                iperf.bind_address = self._BIND_ADDRESS
                iperf.port = TrafficTestDriver._get_port()
            else:
                # Downlink: we send to the address and port given by the
                # instance
                iperf = iperf3.Client()
                iperf.bandwidth = self._BANDWIDTH
                iperf.bind_address = self._BIND_ADDRESS
                iperf.duration = instance.duration
                iperf.port = instance.port
                iperf.protocol = 'udp' if instance.is_udp else 'tcp'
                iperf.server_hostname = instance.ip.exploded
            iperfs.append(iperf)
        self._iperfs = tuple(iperfs)

    @property
    def results(self):
        ''' Retrieve the results for this driver; remains None until the tuple
        of iperf3.TestResult objects has been generated by the driver '''
        return self._results

    def run(self):
        ''' Runs the test (servers), retrieves results, caches them, and
        transmits them back to the client '''
        self._barrier.reset()

        # Create the TrafficResponse, Server type message: one
        # (ip, port, mac) entry per iperf3 object; the port is 0 for anything
        # that is not an iperf3 server
        server_instances = [
            TrafficServerInstance(
                ipaddress.ip_address(iperf.bind_address),
                iperf.port if 's' == iperf.role else 0,
                mac)
            for iperf, mac in zip(self._iperfs, self._get_macs())]
        self._server.send_resp(TrafficResponseType.SERVER, id=id(self),
                               payload=server_instances)

        # Now wait for START message
        self._server.wait_for_start_msg(self)

        self._server.log.debug(
            'Ready to start iperf3 servers for driver %d', id(self))
        buffers = []  # One private result list per iperf3 thread
        threads = []
        for iperf in self._iperfs:
            buf = []
            thread = threading.Thread(
                target=self._run_iperf3, args=(buf, iperf))
            buffers.append(buf)
            threads.append(thread)
            thread.start()
        self._server.log.debug('Driver %d has started %d iperf3 servers',
                               id(self), len(self._iperfs))

        # Send STARTED message to let client know that we've spun everything up
        self._server.send_resp(
            TrafficResponseType.STARTED, id=id(self))

        # Now wait for the threads to hit the barrier and join
        self._barrier.wait()
        for thread in threads:
            thread.join()
        self._server.log.debug(
            'Driver %d has joined its iperf3 servers', id(self))

        # Flatten the per-thread buffers, in start order, into an immutable
        # tuple for storing results
        self._results = tuple(
            result for buf in buffers for result in buf)

        # Create and send the RESULTS message
        self._server.send_resp(TrafficResponseType.RESULTS, id=id(self),
                               payload=self._results)

        self._server.log.debug(
            'Driver %d has cached and transmitted results', id(self))

        # Driver thread joins here, resources should be freed after this
        # Driver thread joins here, resources should be freed after this


def main():
    ''' Parse the command line, then create the dispatcher and serve until it
    is shut down

    Command line: [-d] [-L LEVEL] [host] [port]
    '''
    default_host = ipaddress.ip_address('127.0.0.1')
    default_port = 62462  # 'MAGMA' in telephone keypad format
    level_names = ('NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL')

    cli = argparse.ArgumentParser()

    # Run in the background?
    cli.add_argument(
        '-d', '--daemon',
        action='store_true',
        default=False,
        help='Specify to run server as daemon. Default: False')

    # Optional first positional: the address to bind
    cli.add_argument(
        'host',
        nargs='?',
        type=ipaddress.ip_address,
        default=default_host,
        help='Specify IPv4/6 bind address. Default: 127.0.0.1')

    # Verbosity of the log output
    cli.add_argument(
        '-L', '--log-level',
        choices=level_names,
        default='DEBUG',
        help='Specify a log level for the system')

    # Optional second positional: the port to listen on
    cli.add_argument(
        'port',
        nargs='?',
        type=int,
        default=default_port,
        help='Specify alternative port. Default: 62462')

    options = cli.parse_args()

    # Bring up the dispatcher; serve_forever() blocks until shutdown
    TrafficTestServerDispatcher(
        daemon=options.daemon,
        host=options.host,
        port=options.port,
        loglevel=options.log_level,
    ).serve_forever()


# Start the dispatcher only when executed as a script, not when imported
if __name__ == '__main__':
    main()
